import argparse

from torch.utils.data import DataLoader
import numpy as np
import torch
import torch.nn as nn

from torch.autograd import Variable
from torch import nn, optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import random_split

from process import get_data
from dataset import AgDataset
from LTCN import ITNT


def collate_fn(data: list):
    """Collate a batch of (review, label) pairs for the DataLoader.

    Args:
        data: list of (review, label) tuples as yielded by the Dataset.

    Returns:
        A tuple ``(reviews, labels)`` where ``reviews`` is a plain Python
        list (the model handles raw, variable-length inputs) and ``labels``
        is a 1-D float numpy array.
    """
    reviews = [datum[0] for datum in data]
    labels = np.array([float(datum[1]) for datum in data])
    return reviews, labels


if __name__ == '__main__':

    # Command-line configuration. NOTE(review): several values further down
    # were previously hard-coded literals that duplicated these defaults;
    # they now read from `config`, with defaults equal to the old literals
    # so the default behaviour is unchanged.
    args = argparse.ArgumentParser()
    args.add_argument('--mode', type=str, default='test')  # 'train' or 'test'
    args.add_argument('--epochs', type=int, default=50)
    args.add_argument('--batch', type=int, default=20)
    args.add_argument('--strmaxlen', type=int, default=1000)
    args.add_argument('--embedding', type=int, default=8)
    args.add_argument('--lr', type=float, default=1e-4)
    args.add_argument('--convchannel', type=int, default=200)
    args.add_argument('--tsize', type=int, default=1000)
    args.add_argument('--lrstep', type=int, default=1000)
    args.add_argument('--level', type=int, default=3)
    # `type=bool` is an argparse pitfall: bool('False') is True, so the flag
    # could never be turned off from the command line. Parse the string
    # explicitly so '--attention False' actually yields False.
    args.add_argument('--attention', type=lambda s: s.lower() in ('true', '1', 'yes'),
                      default=False)
    config = args.parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    path_info = "../../../dataset/train_info.xlsx"
    path_label = "../../../dataset/train_label.xlsx"

    data_train_info, data_train_label = get_data(path_info, path_label)

    # NOTE(review): max length 500 here does not match --strmaxlen (1000);
    # kept as-is to preserve behaviour — confirm which value is intended.
    dataset = AgDataset(data_train_info, data_train_label, 500)

    # 70/30 train/test split.
    train_size = int(len(dataset) * 0.7)
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    model = ITNT(config.embedding, 500, channel_size=config.convchannel,
                 T_size=config.tsize, level=config.level, attention=config.attention)
    # model = model.to(device)  # NOTE(review): model is never moved to `device` — confirm CPU-only is intended

    if config.mode == 'train':
        train_loader = DataLoader(dataset=train_dataset, batch_size=config.batch,
                                  shuffle=True, collate_fn=collate_fn, num_workers=2)
    elif config.mode == 'test':
        model.load_state_dict(torch.load('model.pkl'))
        test_loader = DataLoader(dataset=test_dataset, batch_size=config.batch,
                                 shuffle=True, collate_fn=collate_fn, num_workers=2)

if config.mode == 'train':
    # Standard supervised training loop: cross-entropy on class logits,
    # Adam with plateau-based LR decay, checkpoint saved once at the end.
    params = filter(lambda p: p.requires_grad, model.parameters())
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params, lr=config.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1000, factor=0.95,
                                                           verbose=False, mode='min')

    total_batch = len(train_loader)
    # Print roughly 10 progress lines per epoch; the old float expression
    # `total_batch / 10` made `i % step` fragile — use an integer step and
    # guard against epochs with fewer than 10 batches.
    log_every = max(total_batch // 10, 1)
    # NOTE(review): epoch count is hard-coded to 20 and ignores --epochs
    # (default 50); kept at 20 to preserve existing behaviour — confirm intent.
    for epoch in range(20):
        avg_loss = 0.0
        model.train()
        for i, data in enumerate(train_loader):
            data_info, data_label = data

            outputs = model(data_info)
            # Variable is deprecated since torch 0.4; tensors carry autograd directly.
            label_vars = torch.from_numpy(data_label).long()

            loss = criterion(outputs, label_vars)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Scheduler is stepped per batch on the raw batch loss; with
            # patience=1000 this decays the LR roughly every 1000 batches.
            scheduler.step(loss.item())
            avg_loss += loss.item()

            if i == 0 or i % log_every == 0:
                print('Batch : ', i + 1, '/', total_batch, ', Loss in this minibatch: ', loss.item())

        train_loss = float(avg_loss / total_batch)
        print('epoch:', epoch, ' train_loss:', train_loss)
        # Append per-epoch loss to a log file; `with` guarantees the handle closes.
        with open("train_save.txt", "a+") as f:
            print(file=f)
            print(train_loss, file=f)
    torch.save(model.state_dict(), 'model.pkl')

if config.mode == 'test':
    # Evaluation: reload the checkpoint and measure loss/accuracy on the
    # held-out split. The previous optimizer/scheduler construction and
    # zero_grad()/step() calls were no-ops here (no backward pass ever ran)
    # and have been removed.
    model.load_state_dict(torch.load('model.pkl'))
    model.eval()
    criterion = nn.CrossEntropyLoss()

    correct = 0
    wrong = 0
    loss = None
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for data_info, data_label in test_loader:
            preds = model(data_info)
            label_vars = torch.from_numpy(data_label).long()
            loss = criterion(preds, label_vars)
            pred_classes = preds.argmax(1)  # hoisted out of the per-sample loop
            for j in range(len(preds)):
                print(pred_classes[j])
                if pred_classes[j] == data_label[j]:
                    correct += 1
                else:
                    wrong += 1

    # NOTE(review): this is the LAST batch's loss only, not an epoch average;
    # kept as-is to match the original reporting.
    test_loss = loss.item()
    test_accuracy = float(correct) / (wrong + correct)
    print('test loss: %.4f' % test_loss, "test accuracy: ", test_accuracy)

    # Append results to a log file; `with` guarantees the handle closes.
    with open("test_save.txt", "a+") as f:
        print(file=f)
        print(test_loss, test_accuracy, file=f)